{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1438d967-9e38-4a54-8c2b-2d28ff7ad305",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,shutil\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8d876c7-85f5-4ac8-8c2f-be92210cbf57",
   "metadata": {},
   "source": [
    "[kaggle 猫狗数据集下载](https://www.microsoft.com/en-us/download/details.aspx?id=54765)  \n",
    "\n",
    "一共有25000张图片"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06586bf4-c2bf-4dd9-b1ff-bb1e7130a646",
   "metadata": {},
   "source": [
    "### 通用数据集分割流程\n",
    "\n",
    "原始数据集结构：\n",
    " - PetImages\n",
    "   - Cat\n",
    "   - Dog\n",
    "\n",
    "处理之后的数据集结构：\n",
    "- PetImages\n",
    "  - train\n",
    "    - Cat\n",
    "    - Dog\n",
    "  - test\n",
    "    - Cat\n",
    "    - Dog\n",
    "\n",
    "代码具有普适性，可直接用于处理类似结构的数据集：只需修改 `root_dir`，并根据需要调整 `test_rate`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c6821185-1711-4967-860b-5ff16ca715dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据集处理完成\n"
     ]
    }
   ],
   "source": [
    "# Root folder that initially holds one sub-folder per class (here: Cat/ and Dog/).\n",
    "root_dir=r\"../datasets/PetImages\"\n",
    "test_rate=0.1  # fraction of every class moved to the test set (train:test = 9:1); change as needed\n",
    "seed = 42  # fixed seed so a fresh run always selects the same test files\n",
    "\n",
    "train_dir = os.path.join(root_dir, 'train')\n",
    "test_dir = os.path.join(root_dir, 'test')\n",
    "\n",
    "# Only genuine class folders count as categories. Stray files and the\n",
    "# train/test folders themselves are skipped, so re-running this cell after a\n",
    "# completed split is a no-op instead of trying to split train/ and test/ again.\n",
    "categories = sorted(\n",
    "    name for name in os.listdir(root_dir)\n",
    "    if name not in ('train', 'test') and os.path.isdir(os.path.join(root_dir, name))\n",
    ")\n",
    "\n",
    "os.makedirs(train_dir, exist_ok=True)\n",
    "os.makedirs(test_dir, exist_ok=True)\n",
    "\n",
    "# Private generator: reproducible sampling without touching the global random state.\n",
    "rng = random.Random(seed)\n",
    "\n",
    "for category in categories:\n",
    "    src_dir = os.path.join(root_dir, category)\n",
    "    # os.listdir returns entries in arbitrary order; sort them so the seeded\n",
    "    # sample is identical across machines and runs.\n",
    "    filenames = sorted(os.listdir(src_dir))\n",
    "    test_num = int(len(filenames) * test_rate)\n",
    "    test_filenames = rng.sample(filenames, test_num)\n",
    "\n",
    "    # Move the sampled test images.\n",
    "    test_category_dir = os.path.join(test_dir, category)\n",
    "    os.makedirs(test_category_dir, exist_ok=True)\n",
    "    for test_filename in test_filenames:\n",
    "        src_path = os.path.join(src_dir, test_filename)\n",
    "        tgt_path = os.path.join(test_category_dir, test_filename)\n",
    "        shutil.move(src_path, tgt_path)\n",
    "\n",
    "    # Move the training images (everything still left in src_dir).\n",
    "    train_category_dir = os.path.join(train_dir, category)\n",
    "    os.makedirs(train_category_dir, exist_ok=True)\n",
    "    for train_filename in os.listdir(src_dir):\n",
    "        src_path = os.path.join(src_dir, train_filename)\n",
    "        tgt_path = os.path.join(train_category_dir, train_filename)\n",
    "        shutil.move(src_path, tgt_path)\n",
    "\n",
    "    # Remove the now-empty original class folder.\n",
    "    os.rmdir(src_dir)\n",
    "\n",
    "print(\"数据集处理完成\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84bb2b5f-3651-418c-b390-d40e2ff28a2f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
